{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Audio data preparation from zero to PyTorch's DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from lhotse import WavAugmenter, CutSet, Fbank, FbankConfig\n",
    "from lhotse.recipes.librispeech import download_and_untar, prepare_librispeech"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting mini LibriSpeech and creating manifests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "download_and_untar('data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "manifests = prepare_librispeech('data/LibriSpeech')\n",
    "train = manifests['train-clean-5']\n",
    "dev = manifests['dev-clean-2']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating the Cuts\n",
    "\n",
    "### First, we create the \"starting point\" cut sets - i.e. cuts that actually span full recordings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cuts_from_split(split):\n",
    "    \"\"\"Build a CutSet from the 'recordings' and 'supervisions' manifests of one split.\"\"\"\n",
    "    return CutSet.from_manifests(\n",
    "        recordings=split['recordings'],\n",
    "        supervisions=split['supervisions'],\n",
    "    )\n",
    "\n",
    "\n",
    "train_cuts = cuts_from_split(train)\n",
    "dev_cuts = cuts_from_split(dev)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### We can see that the cut durations are far from equal, whereas we'd like to use 5-second-long cuts for this experiment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x7f8990e359d0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAASuUlEQVR4nO3de5CddX3H8fe3xAuyNuGiW5pkJkyNdigpSLaI1bYrtE4AY/hDHRyqwaaTGQdb1LQl2mk7znSmUatUZzq0GbHES10pgmSCWmnI1vEP0ASBcLNEjJAViWiILnjL9Ns/zi922d2wZ7Pn5HmW3/s1s3Oe5/dczmfP5nzOc55zSWQmkqRnv19pOoAk6diw8CWpEha+JFXCwpekSlj4klSJBU0HADjllFNy2bJlc97Pk08+yQknnDD3QH3Q5mzQ7nxtzgbtztfmbNDufG3OBp18DzzwwOOZ+aKuN8rMxn9WrlyZvbBjx46e7Kcf2pwts9352pwts9352pwts9352pwts5MP2Jmz6FpP6UhSJSx8SaqEhS9JlbDwJakSFr4kVcLCl6RKWPiSVAkLX5IqYeFLUiVa8dUKkgSwbOPNjVzv3k0XNXK9x5pH+JJUCQtfkiph4UtSJSx8SapEV4UfEXsjYndE3BkRO8vYSRFxS0Q8WC5PLOMRER+NiD0RcXdEnN3PX0CS1J3ZHOG/JjPPysyhMr8R2J6Zy4HtZR7gAmB5+VkPXN2rsJKkozeXUzprgC1legtw8YTxT5Tv6L8NWBQRp87heiRJPRCZOfNKEd8GDgAJ/Gtmbo6IJzJzUVkewIHMXBQR24BNmfnVsmw7cGVm7py0z/V0ngEwODi4cmRkZM6/zPj4OAMDA3PeTz+0ORu0O1+bs0G787U5G0zNt3vsYCM5VixeOGVsPtx2q1ev3jXhrMuMuv3g1aszcywiXgzcEhEPTFyYmRkRMz9yPH2bzcBmgKGhoRweHp7N5tMaHR2lF/vphzZng3bna3M2aHe+NmeDqfkua+qDV5cOTxmbD7fdbHV1Siczx8rlfuBG4BzgscOnasrl/rL6GLB0wuZLypgkqUEzFn5EnBARLzw8DbwWuAfYCqwtq60FbirTW4G3lnfrnAsczMxHe55ckjQr3ZzSGQRu7JymZwHw75n5pYj4OnBdRKwDvgO8qaz/BeBCYA/wFPC2nqeWJM3ajIWfmQ8BZ04z/gPg/GnGE7i8J+kkST3jJ20lqRIWviRVwsKXpEpY+JJUCQtfkiph4UtSJSx8SaqEhS9JlbDwJakSFr4kVcLCl6RKWPiSVAkLX5IqYeFLUiUsfEmqhIUvSZWw8CWpEha+JFXCwpekSlj4klQJC1+SKmHhS1IlLHxJqoSFL0mVsPAlqRIWviRVwsKXpEpY+JJUCQtfkiph4UtSJSx8SapE14UfEcdFxDciYluZPy0ibo+IPRHx2Yh4bhl/XpnfU5Yv6090SdJszOYI/wrg/gnz7weuysyXAAeAdWV8HXCgjF9V1pMkNayrwo+IJcBFwMfKfADnAdeXVbYAF5fpNWWesvz8sr4kqUGRmTOvFHE98A/AC4G/AC4DbitH8UTEUuCLmXlGRNwDrMrMfWXZt4BXZObjk/a5HlgPMDg4uHJkZGTOv8z4+DgDAwNz3k8/tDkbtDtfm7NBu/O1ORtMzbd77GAjOVYsXjhlbD7cdqtXr96VmUPdbrNgphUi4nXA/szcFRHDcwk4UWZuBjYDDA0N5fDw3Hc9OjpKL/bTD23OBu3O1+Zs0O58bc4GU/NdtvHmRnLsvXR4yth8uO1ma8bCB14FvD4iLgSeD/wq8BFgUUQsyMxDwBJgrKw/BiwF9kXEAmAh8INZJ5Mk9dSM5/Az8z2ZuSQzlwGXALdm5qXADuANZbW1wE1lemuZpyy/Nbs5byRJ6qu5vA//SuDdEbEHOBm4poxfA5xcxt8NbJxbRElSL3Rz
SueXMnMUGC3TDwHnTLPOT4E39iCbJKmH/KStJFXCwpekSlj4klQJC1+SKmHhS1IlLHxJqoSFL0mVsPAlqRIWviRVwsKXpEpY+JJUCQtfkiph4UtSJSx8SaqEhS9JlbDwJakSFr4kVcLCl6RKWPiSVAkLX5IqYeFLUiUsfEmqhIUvSZWw8CWpEha+JFViQdMBJLXLso03H7Pr2rDiEJcdw+urnUf4klQJC1+SKmHhS1IlLHxJqsSMhR8Rz4+Ir0XEXRFxb0S8r4yfFhG3R8SeiPhsRDy3jD+vzO8py5f191eQJHWjmyP8nwHnZeaZwFnAqog4F3g/cFVmvgQ4AKwr668DDpTxq8p6kqSGzVj42TFeZp9TfhI4D7i+jG8BLi7Ta8o8Zfn5ERE9SyxJOipdncOPiOMi4k5gP3AL8C3gicw8VFbZBywu04uBRwDK8oPAyb0MLUmavcjM7leOWATcCPwNcG05bUNELAW+mJlnRMQ9wKrM3FeWfQt4RWY+Pmlf64H1AIODgytHRkbm/MuMj48zMDAw5/30Q5uzQbvztTkbtDvf0WTbPXawT2mmGjweHvvJMbu6I1qxeOGUsTb/XaGTb/Xq1bsyc6jbbWb1SdvMfCIidgCvBBZFxIJyFL8EGCurjQFLgX0RsQBYCPxgmn1tBjYDDA0N5fDw8GyiTGt0dJRe7Kcf2pwN2p2vzdmg3fmOJtux/OTrhhWH+NDu5j/wv/fS4Sljbf67QiffbHXzLp0XlSN7IuJ44I+A+4EdwBvKamuBm8r01jJPWX5rzuZphCSpL7p5aD0V2BIRx9F5gLguM7dFxH3ASET8PfAN4Jqy/jXAJyNiD/BD4JI+5JYkzdKMhZ+ZdwMvn2b8IeCcacZ/CryxJ+kkST3jJ20lqRIWviRVwsKXpEpY+JJUCQtfkiph4UtSJSx8SaqEhS9JlbDwJakSFr4kVcLCl6RKWPiSVAkLX5IqYeFLUiUsfEmqhIUvSZWw8CWpEha+JFXCwpekSlj4klQJC1+SKmHhS1IlLHxJqoSFL0mVsPAlqRIWviRVwsKXpEpY+JJUCQtfkiph4UtSJSx8SarEjIUfEUsjYkdE3BcR90bEFWX8pIi4JSIeLJcnlvGIiI9GxJ6IuDsizu73LyFJmlk3R/iHgA2ZeTpwLnB5RJwObAS2Z+ZyYHuZB7gAWF5+1gNX9zy1JGnWZiz8zHw0M+8o0z8G7gcWA2uALWW1LcDFZXoN8InsuA1YFBGn9jy5JGlWIjO7XzliGfAV4Azg4cxcVMYDOJCZiyJiG7ApM79alm0HrszMnZP2tZ7OMwAGBwdXjoyMzPmXGR8fZ2BgYM776Yc2Z4N252tzNmh3vqPJtnvsYJ/STDV4PDz2k2N2dUe0YvHCKWNt/rtCJ9/q1at3ZeZQt9ss6HbFiBgAPge8MzN/1On4jszMiOj+kaOzzWZgM8DQ0FAODw/PZvNpjY6O0ov99EObs0G787U5G7Q739Fku2zjzf0JM40NKw7xod1d11Df7L10eMpYm/+u0Mk3W129SycinkOn7D+dmTeU4ccOn6opl/vL+BiwdMLmS8qYJKlB3bxLJ4BrgPsz88MTFm0F1pbptcBNE8bfWt6tcy5wMDMf7WFmSdJR6Oa51KuAtwC7I+LOMvZeYBNwXUSsA74DvKks+wJwIbAHeAp4W08TS5KOyoyFX158jSMsPn+a9RO4fI65JEk95idtJakSFr4kVcLCl6RKWPiSVAkLX5IqYeFLUiUsfEmqhIUvSZWw8CWpEs1/TZ0kNWzZNN8QumHFoWPyzaF7N13U9+s4zCN8SaqEhS9JlbDwJakSFr4kVcIXbaWWmu6FxNk6Vi88an7wCF+SKmHhS1IlLHxJqoSFL0mVsPAlqRIWviRVwsKXpEpY+JJUCQtfkiph4UtSJSx8SaqEhS9JlbDwJakSFr4kVcLCl6RKWPiSVIkZCz8iPh4R+yPingljJ0XELRHxYLk8sYxHRHw0IvZExN0R
cXY/w0uSutfNEf61wKpJYxuB7Zm5HNhe5gEuAJaXn/XA1b2JKUmaqxkLPzO/Avxw0vAaYEuZ3gJcPGH8E9lxG7AoIk7tVVhJ0tGLzJx5pYhlwLbMPKPMP5GZi8p0AAcyc1FEbAM2ZeZXy7LtwJWZuXOafa6n8yyAwcHBlSMjI3P+ZcbHxxkYGJjzfvqhzdmg3fnanA36l2/32ME572PweHjsJz0I0ydtznessq1YvPCothsfH2f16tW7MnOo223m/J+YZ2ZGxMyPGlO32wxsBhgaGsrh4eG5RmF0dJRe7Kcf2pwN2p2vzdmgf/l68Z+Pb1hxiA/tnvPdvG/anO9YZdt76fBRbTc6OjrrbY72XTqPHT5VUy73l/ExYOmE9ZaUMUlSw4628LcCa8v0WuCmCeNvLe/WORc4mJmPzjGjJKkHZny+EhGfAYaBUyJiH/B3wCbguohYB3wHeFNZ/QvAhcAe4CngbX3ILEk6CjMWfma++QiLzp9m3QQun2soSVLvtfPVEqkllnXxwumGFYd68gKr1G9+tYIkVcLCl6RKWPiSVAkLX5IqYeFLUiUsfEmqhIUvSZWw8CWpEn7wSvNCNx+AkvTMPMKXpEp4hD8Hsznq7OXH7/duuqgn+5FUl3lf+BNL1+806b8mTq1sWHGIZ8E/ValxntKRpEpY+JJUCQtfkirhidF5qB/n0X39Q3r28whfkiph4UtSJSx8SaqEhS9JlbDwJakSFr4kVcLCl6RKWPiSVAkLX5IqYeFLUiUsfEmqhIUvSZWw8CWpEha+JFWiL4UfEasi4psRsSciNvbjOiRJs9Pzwo+I44B/Bi4ATgfeHBGn9/p6JEmz048j/HOAPZn5UGb+HBgB1vTheiRJsxCZ2dsdRrwBWJWZf1rm3wK8IjPfMWm99cD6Mvsy4Js9uPpTgMd7sJ9+aHM2aHe+NmeDdudrczZod742Z4NOvhMy80XdbtDYf3GYmZuBzb3cZ0TszMyhXu6zV9qcDdqdr83ZoN352pwN2p2vzdngl/mWzWabfpzSGQOWTphfUsYkSQ3qR+F/HVgeEadFxHOBS4CtfbgeSdIs9PyUTmYeioh3AP8JHAd8PDPv7fX1HEFPTxH1WJuzQbvztTkbtDtfm7NBu/O1ORscRb6ev2grSWonP2krSZWw8CWpEvO+8CNiaUTsiIj7IuLeiLii6UzTiYjjIuIbEbGt6SwTRcSiiLg+Ih6IiPsj4pVNZ5ooIt5V/q73RMRnIuL5Def5eETsj4h7JoydFBG3RMSD5fLEFmX7YPnb3h0RN0bEoiayHSnfhGUbIiIj4pQ2ZYuIPyu3370R8YEmsh0pX0ScFRG3RcSdEbEzIs6ZaT/zvvCBQ8CGzDwdOBe4vKVf5XAFcH/TIabxEeBLmfmbwJm0KGNELAb+HBjKzDPovAngkmZTcS2watLYRmB7Zi4Htpf5JlzL1Gy3AGdk5m8D/wO851iHmuBapuYjIpYCrwUePtaBJriWSdki4jV0viXgzMz8LeAfG8h12LVMve0+ALwvM88C/rbMP6N5X/iZ+Whm3lGmf0ynsBY3m+rpImIJcBHwsaazTBQRC4HfB64ByMyfZ+YTzaaaYgFwfEQsAF4AfLfJMJn5FeCHk4bXAFvK9Bbg4mMaqpguW2Z+OTMPldnb6HwuphFHuO0ArgL+CmjsHSRHyPZ2YFNm/qyss/+YByuOkC+BXy3TC+nivjHvC3+iiFgGvBy4vdkkU/wTnX/Q/9t0kElOA74P/Fs53fSxiDih6VCHZeYYnaOqh4FHgYOZ+eVmU01rMDMfLdPfAwabDPMM/gT4YtMhJoqINcBYZt7VdJZpvBT4vYi4PSL+OyJ+p+lAk7wT+GBEPELnfjLjs7dnTeFHxADwOeCdmfmjpvMcFhGvA/Zn5q6ms0xjAXA2cHVmvhx4kuZOR0xRzoWvofPA9OvACRHxx82membZeZ9z697rHBF/Tef056ebznJYRLwAeC+d0xFttAA4ic6p
4r8ErouIaDbS07wdeFdmLgXeRXmm/kyeFYUfEc+hU/afzswbms4zyauA10fEXjrfHHpeRHyq2Ui/tA/Yl5mHnxFdT+cBoC3+EPh2Zn4/M38B3AD8bsOZpvNYRJwKUC4be+o/nYi4DHgdcGm264M3v0Hnwfyucv9YAtwREb/WaKr/tw+4ITu+RucZeiMvKh/BWjr3CYD/oPNNxc9o3hd+ecS9Brg/Mz/cdJ7JMvM9mbmkfMnRJcCtmdmKo9TM/B7wSES8rAydD9zXYKTJHgbOjYgXlL/z+bToReUJttK581Eub2owy9NExCo6pxNfn5lPNZ1noszcnZkvzsxl5f6xDzi7/Ltsg88DrwGIiJcCz6Vd3575XeAPyvR5wIMzbpGZ8/oHeDWdp9B3A3eWnwubznWErMPAtqZzTMp0FrCz3H6fB05sOtOkfO8DHgDuAT4JPK/hPJ+h83rCL+gU1DrgZDrvznkQ+C/gpBZl2wM8MuG+8S9tuu0mLd8LnNKWbHQK/lPl394dwHltuu1K9+0C7qLzuuXKmfbjVytIUiXm/SkdSVJ3LHxJqoSFL0mVsPAlqRIWviRVwsKXpEpY+JJUif8Dv3MM4zXNs7gAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "pd.Series(c.duration for c in train_cuts).hist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### We can cut the longer recordings into 5-second cuts by traversing them in windows; the leftover portion of a recording may still be shorter, so we pad it with silence up to 5 seconds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1519, 4571)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "CUT_DURATION = 5  # seconds\n",
    "\n",
    "# Traverse every cut in fixed-size windows, then pad whatever is still too short.\n",
    "windowed_cuts = train_cuts.cut_into_windows(CUT_DURATION)\n",
    "train_cuts_filt = windowed_cuts.pad(CUT_DURATION)\n",
    "\n",
    "# Sanity check: no cut may deviate from the target duration.\n",
    "assert not any(cut.duration != CUT_DURATION for cut in train_cuts_filt)\n",
    "\n",
    "(len(train_cuts), len(train_cuts_filt))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's create a simple dataset for our very specific task: classifying whether an audio clip has been reverberated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "\n",
    "class ReverbDetectionDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset of (features, label) pairs for reverb detection.\n",
    "\n",
    "    Every item is the feature matrix of a single cut. A fair coin flip decides\n",
    "    whether the 'reverb' augmenter is applied before feature extraction; the label\n",
    "    is 1.0 when it was applied and 0.0 otherwise.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, cuts):\n",
    "        self.cuts = cuts\n",
    "        # Snapshot the ids so that integer indices can be translated to cut ids.\n",
    "        self.cut_ids = list(cuts.ids)\n",
    "        fbank_config = FbankConfig(num_mel_bins=80)\n",
    "        self.extractor = Fbank(fbank_config)\n",
    "        sampling_rate = cuts[0].sampling_rate\n",
    "        self.augmenter = WavAugmenter.create_predefined('reverb', sampling_rate)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.cuts)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        cut_id = self.cut_ids[idx]\n",
    "        cut = self.cuts[cut_id]\n",
    "        # Coin flip: 1 -> extract features from reverberated audio, 0 -> from the clean audio.\n",
    "        is_reverberated = random.choice([0, 1])\n",
    "        if is_reverberated:\n",
    "            augmenter = self.augmenter\n",
    "        else:\n",
    "            augmenter = None\n",
    "        feats = cut.compute_features(self.extractor, augmenter)\n",
    "        # NOTE: the label is a plain Python float, which default batching turns into a float64 tensor.\n",
    "        return torch.from_numpy(feats), float(is_reverberated)"
   ]
  },
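  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating the DataLoader is very simple - no `collate_fn` is needed at all, since we used CutSet's capabilities to bring the training data to equal length. The dev cuts are left at their original lengths, so the dev DataLoader uses a batch size of 1."
   ]
  },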
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training cuts were all brought to the same duration, so items can be batched directly.\n",
    "train_dset = ReverbDetectionDataset(train_cuts_filt)\n",
    "train_dloader = torch.utils.data.DataLoader(\n",
    "    train_dset,\n",
    "    batch_size=8,\n",
    "    shuffle=True,\n",
    "    num_workers=2,\n",
    ")\n",
    "\n",
    "# Dev cuts were neither windowed nor padded, hence one item per batch.\n",
    "val_dset = ReverbDetectionDataset(dev_cuts)\n",
    "val_dloader = torch.utils.data.DataLoader(\n",
    "    val_dset,\n",
    "    batch_size=1,\n",
    "    num_workers=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training DataLoader shapes:\n",
      "torch.Size([8, 500, 80]) torch.Size([8])\n",
      "tensor([1., 1., 0., 1., 0., 1., 0., 0.], dtype=torch.float64)\n",
      "Dev DataLoader shapes:\n",
      "torch.Size([1, 1166, 80]) torch.Size([1])\n"
     ]
    }
   ],
   "source": [
    "# Peek at the first training batch (if there is one) instead of iterating the whole loader.\n",
    "first_batch = next(iter(train_dloader), None)\n",
    "if first_batch is not None:\n",
    "    feats, targets = first_batch\n",
    "    print('Training DataLoader shapes:')\n",
    "    print(feats.shape, targets.shape)\n",
    "    print(targets)\n",
    "\n",
    "# Same for the dev loader, minus the label dump.\n",
    "first_batch = next(iter(val_dloader), None)\n",
    "if first_batch is not None:\n",
    "    feats, targets = first_batch\n",
    "    print('Dev DataLoader shapes:')\n",
    "    print(feats.shape, targets.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}